import os
import time
from pathlib import Path
from typing import Optional

import lightning as L
import torch
import typer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from lightning.pytorch.utilities import rank_zero_only
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

from oh_my_bloom.datamodule import ChatDataModule
from oh_my_bloom.model import DeepSpeedChatBloom
from oh_my_bloom.tokenizer import get_chatbloom_tokeizer

# Disable HF tokenizers parallelism to avoid fork-related warnings in DataLoader workers.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
            
def run(plm_path: str, 
        output_dir: str,
        ckpt_convert_dir: str,
        cache_dir: str,
        resume_ckpt_path: Optional[str] = None,
        version: str = 'v1',
        model_max_length: int = 512,
        lr: float = 1e-5,
        num_warmup_steps: int = 100,
        weight_decay: float = 0.0,
        seed: int = 42,  
        num_workers: int = 10,
        batch_size: int = 10,
        max_epochs: int = 3,
        limit_train_batches: Optional[float] = None,
        limit_val_batches: Optional[float] = None,
        devices: int = 6, 
        accelerator: str = 'cuda',
        strategy: Optional[str] = 'deepspeed_stage_2', 
        precision: str = '16-mixed',
        fast_dev_run: bool = False):
    """
    plm_path (str): 预训练模型路径.
    output_dir (str): 日志以及模型参数保存目录.
    ckpt_convert_dir (str): deepspeed类型参数转换后的保存目录.
    version (str, optional): 版本. Defaults to 'v1'.
    model_max_length (int, optional): 模型最大输入长度. Defaults to 512.
    lr (float, optional): 学习率. Defaults to 2e-5.
    num_warmup_steps (int, optional): 预热的步数. Defaults to 100.
    weight_decay (float, optional): 权重衰减. Defaults to 0.0.
    seed (int, optional): 随机数种子. Defaults to 42.
    num_workers (int, optional): dataloader中的进程数. Defaults to 10.
    batch_size (int, optional): 批次大小. Defaults to 10.
    max_epochs (int, optional): 最大训练迭代. Defaults to 3.
    limit_train_batches (Optional[float], optional): 限制训练批次. Defaults to None.
    limit_val_batches (Optional[float], optional): 限制验证批次. Defaults to None.
    devices (int, optional): 设备. Defaults to 6.
    accelerator (str, optional): 加速器. Defaults to 'cuda'.
    strategy (Optional[str], optional): 加速策略. Defaults to 'deepspeed_stage_2'.
    precision (str, optional):训练精度. Defaults to '16-mixed'.
    fast_dev_run (bool, optional): 快速开发模式. Defaults to False.
    """

    L.seed_everything(seed=seed, workers=True)
    
    # 'high' enables TF32 matmuls on Ampere and newer GPUs for faster training.
    torch.set_float32_matmul_precision('high')

    tokenizer = get_chatbloom_tokeizer(plm_path=plm_path, model_max_length=model_max_length)

    # NOTE: the instruction dataset directory is hard-coded to this machine's path.
    dm = ChatDataModule(tokenizer=tokenizer,
                        data_dir='/root/autodl-tmp/OMInstructions',
                        cache_dir=cache_dir,
                        batch_size=batch_size,
                        num_workers=num_workers)

    if strategy is not None and strategy.startswith('deepspeed'):
        # The '*_offload' strategy variants enable DeepSpeed CPU offloading.
        offload = strategy.endswith('offload')
        model = DeepSpeedChatBloom(plm_path=plm_path,
                                   chat_tokenizer=tokenizer,
                                   lr=lr,
                                   weight_decay=weight_decay,
                                   num_warmup_steps=num_warmup_steps,
                                   offload=offload)
    elif strategy is not None and strategy.startswith('fsdp'):
        # FSDP training is not implemented yet.
        raise NotImplementedError('FSDP training is not supported yet.')
    else:
        raise ValueError(f'Unsupported strategy: {strategy}')
    
    
    # Save logs and checkpoints under a date-stamped directory.
    output_timed_dir = Path(output_dir, time.strftime('%Y-%m-%d'))
    output_timed_dir.mkdir(parents=True, exist_ok=True)

    csv_logger = CSVLogger(output_timed_dir, name='logs', version=version)
    wandb_dir = Path(output_timed_dir, 'wandb_logs')
    wandb_dir.mkdir(exist_ok=True)
    wandb_logger = WandbLogger(project='oh-my-bloom', save_dir=wandb_dir, version=version)

    ckpt_dir = Path(output_timed_dir, 'checkpoints')
    # No monitored metric is set, so the callback keeps the most recent checkpoint
    # and best_model_path points at it. auto_insert_metric_name=False lets the
    # 'train/loss' metric be formatted without its slash breaking the filename.
    model_checkpoint = ModelCheckpoint(dirpath=ckpt_dir,
                                       filename='epoch={epoch}-step={step}-loss={train/loss:.2f}',
                                       auto_insert_metric_name=False,
                                       save_last=True)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    
    trainer = L.Trainer(accelerator=accelerator, 
                        devices=devices, 
                        strategy=strategy, 
                        max_epochs=max_epochs,
                        enable_checkpointing=True,
                        precision=precision,
                        fast_dev_run=fast_dev_run,
                        limit_train_batches=limit_train_batches,
                        limit_val_batches=limit_val_batches,
                        callbacks=[model_checkpoint, lr_monitor],
                        logger=[wandb_logger, csv_logger],
                        profiler='simple')
    
    trainer.fit(model=model, datamodule=dm, ckpt_path=resume_ckpt_path)

    # Convert the DeepSpeed ZeRO checkpoint (a directory of shards) into a single
    # fp32 state dict that Lightning can load directly.
    if strategy.startswith('deepspeed') and not fast_dev_run:
        best_ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
        convert_save_path = Path(ckpt_convert_dir, version, best_ckpt_path.name)
        convert_save_path.parent.mkdir(parents=True, exist_ok=True)
        rank_zero_only(convert_zero_checkpoint_to_fp32_state_dict)(checkpoint_dir=best_ckpt_path, output_file=convert_save_path)
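        # The converted file can then be loaded without DeepSpeed, e.g. (a sketch;
        # the exact arguments depend on the hyperparameters DeepSpeedChatBloom saves):
        #   model = DeepSpeedChatBloom.load_from_checkpoint(convert_save_path)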

if __name__ == "__main__":
    typer.run(run)
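
# Example invocation (a sketch; the script name and all paths are placeholders —
# the four positional arguments map to plm_path, output_dir, ckpt_convert_dir,
# and cache_dir, and Typer exposes the remaining parameters as --options):
#   python train.py /path/to/bloom ./outputs ./converted ./cache \
#       --devices 4 --batch-size 8 --strategy deepspeed_stage_2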